#include <assert.h>
#include <float.h>
#include <stdio.h>
#include <stdlib.h>

#include <algorithm>
#include <iostream>
#include <type_traits>
#include <vector>

#include <cublas_v2.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>


#define ROWS 1024
#define COLS 1024
#define BLOCKSIZE 32
#define DTYPE half
#define VECTOR_LOAD true

// Wrap a CUDA runtime call; on failure prints file/line, the error name and
// string, and the original expression text. No trailing ';' here so call
// sites ending in ';' do not produce an extra empty statement.
#define Check_Cuda_Runtime(op) __check_cuda_runtime((op),#op,__FILE__,__LINE__)
// Returns true when `code` is cudaSuccess, false otherwise (after logging).
bool __check_cuda_runtime(cudaError_t code, const char* op, const char* file, int line){
    if (code != cudaSuccess) {
        // Only resolve the error strings on the failure path.
        const char* cuda_error_name = cudaGetErrorName(code);
        const char* cuda_error_string = cudaGetErrorString(code);
        printf("cuda error happened at %s:%d,errorname:%s,errorstring:%s,op:%s\n",
            file, line, cuda_error_name, cuda_error_string, op);
        return false;
    }
    return true;
}

// Generic element-wise addition kernel (intended for float and int; the
// vector path maps T to float4 or int4).
// Launch as a 1-D grid. When vector_sum is true each thread handles 4
// consecutive elements with a single 128-bit load/store, so the buffers
// must be 16-byte aligned (cudaMalloc base pointers are).
template <typename T>
__global__ void elementwise(T* A, T* B, T* C, int total_elements, bool vector_sum) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (vector_sum) {
        // Vectorized path: 4 elements per thread.
        idx *= 4;
        if (idx + 3 < total_elements) {
            using Vec4 = typename std::conditional<std::is_same<T, float>::value, float4, int4>::type;
            Vec4 vec_a = *reinterpret_cast<Vec4*>(&A[idx]);
            Vec4 vec_b = *reinterpret_cast<Vec4*>(&B[idx]);
            Vec4 vec_c;
            vec_c.x = vec_a.x + vec_b.x;
            vec_c.y = vec_a.y + vec_b.y;
            vec_c.z = vec_a.z + vec_b.z;
            vec_c.w = vec_a.w + vec_b.w;
            *reinterpret_cast<Vec4*>(&C[idx]) = vec_c;
        } else {
            // Scalar tail for total_elements not divisible by 4; the old
            // `idx < total_elements` guard let the vector load read/write
            // up to 3 elements out of bounds.
            for (int i = idx; i < total_elements; ++i) {
                C[i] = A[i] + B[i];
            }
        }
    } else {
        // Scalar path: one element per thread.
        if (idx < total_elements) {
            C[idx] = A[idx] + B[idx];
        }
    }
}

// half specialization: uses half2 packed arithmetic on the vector path
// (2 elements per thread, 32-bit loads) and __hadd on the scalar path.
template <>
__global__ void elementwise(half* A, half* B, half* C, int total_elements, bool vector_sum) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (vector_sum) {
        // Vectorized path: 2 elements per thread via half2.
        idx *= 2;
        if (idx + 1 < total_elements) {
            // __ldg routes the loads through the read-only data cache.
            half2 vec_a = __ldg(reinterpret_cast<half2*>(&A[idx]));
            half2 vec_b = __ldg(reinterpret_cast<half2*>(&B[idx]));
            half2 vec_c = __hadd2(vec_a, vec_b);
            *reinterpret_cast<half2*>(&C[idx]) = vec_c;
        } else if (idx < total_elements) {
            // Scalar tail for odd total_elements; the old guard let the
            // half2 load/store touch one element past the end.
            C[idx] = __hadd(A[idx], B[idx]);
        }
    } else {
        // Scalar path: one element per thread.
        if (idx < total_elements) {
            C[idx] = __hadd(A[idx], B[idx]);
        }
    }
}
    

// Fill the first total_num entries of p with their own index (0, 1, 2, ...),
// converted to T.
template<typename T>
void ptr_data_init(T* p, int total_num){
    T* cursor = p;
    for (int value = 0; value < total_num; ++value) {
        *cursor++ = static_cast<T>(value);
    }
}

// half specialization: go through float and narrow with __float2half,
// since half has no direct constructor from int on the host.
template<>
void ptr_data_init(half* p, int total_num){
    int value = 0;
    while (value < total_num) {
        p[value] = __float2half(static_cast<float>(value));
        ++value;
    }
}


// Host-side launcher for the elementwise kernel.
// Sizes the grid from how many elements each thread consumes: 2 for half
// (half2 path), 4 for other types (float4/int4 path), 1 in scalar mode.
// The previous fixed `/(blocksize*2)` divisor under-provisioned the grid
// in scalar mode (half the elements were never computed) and
// over-provisioned it for the 4-wide path.
template<typename T>
void launch_cuda_add(T* a, T* b, T* c, int total_nums, int blocksize, bool vector_load){
    int elems_per_thread = 1;
    if (vector_load) {
        elems_per_thread = std::is_same<T, half>::value ? 2 : 4;
    }
    int work_per_block = blocksize * elems_per_thread;
    // Ceil-div so a partial tail block is still launched.
    int gridsize = (total_nums + work_per_block - 1) / work_per_block;
    elementwise<T><<<gridsize, blocksize>>>(a, b, c, total_nums, vector_load);
    // Surface bad-launch-configuration errors immediately.
    Check_Cuda_Runtime(cudaGetLastError());
}

// CPU reference: c[i] = a[i] + b[i] for the first total_nums elements.
template<typename T>
void launch_cpu_add(T* a, T* b, T* c, int total_nums){
    for (int i = 0; i != total_nums; ++i) {
        c[i] = a[i] + b[i];
    }
}

// half specialization of the CPU reference: widen both operands to float,
// add, then narrow the sum back to half.
template<>
void launch_cpu_add(half* a, half* b, half* c, int total_nums){
    int i = 0;
    while (i < total_nums) {
        float lhs = __half2float(a[i]);
        float rhs = __half2float(b[i]);
        c[i] = __float2half(lhs + rhs);
        ++i;
    }
}

// Compare two result buffers element by element with an absolute tolerance
// of 1e-5; print every mismatching pair, or a success message if all match.
template<typename T>
void check_cuda_cpu_result(T* a, T* b, int total_nums){
    bool all_match = true;
    for (int i = 0; i < total_nums; ++i) {
        if (fabs(a[i] - b[i]) > 1e-5) {
            printf("a: %f b: %f diff: %f\n", (float)a[i], (float)b[i], fabs(a[i] - b[i]));
            all_match = false;
        }
    }
    if (all_match) {
        printf("cuda and cpu get the same result!\n");
    }
}

// half specialization: widen each element to float before comparing with
// the 1e-5 absolute tolerance; print mismatches, or a success message.
template<>
void check_cuda_cpu_result(half* a, half* b, int total_nums){
    bool all_match = true;
    for (int i = 0; i < total_nums; ++i) {
        float lhs = __half2float(a[i]);
        float rhs = __half2float(b[i]);
        if (fabs(lhs - rhs) > 1e-5) {
            printf("a: %f b: %f diff: %f\n", __half2float(a[i]), __half2float(b[i]), fabs(__half2float(a[i]) - __half2float(b[i])));
            all_match = false;
        }
    }
    if (all_match) {
        printf("cuda and cpu get the same result!\n");
    }
}


// Driver: fill two host buffers, add them on the GPU and on the CPU, and
// compare the results. Sizes/dtype come from the #defines at the top.
int main(){
    using dtype = DTYPE;
    bool vector_load = VECTOR_LOAD;
    int rows = ROWS;
    int cols = COLS;
    int blocksize = BLOCKSIZE;
    int total_nums = rows * cols;
    int total_bytes = rows * cols * sizeof(dtype);

    // Host buffers: inputs, GPU result, CPU reference.
    dtype* ha = (dtype*)malloc(total_bytes);
    dtype* hb = (dtype*)malloc(total_bytes);
    dtype* hc = (dtype*)malloc(total_bytes);
    dtype* hc_ = (dtype*)malloc(total_bytes);
    if (!ha || !hb || !hc || !hc_) {
        printf("host allocation failed\n");
        free(ha); free(hb); free(hc); free(hc_);
        return 1;
    }
    ptr_data_init(ha, total_nums);
    ptr_data_init(hb, total_nums);
    // Note: hc is overwritten by the device-to-host copy below, so it
    // needs no initialization.

    dtype* da;
    dtype* db;
    dtype* dc;
    Check_Cuda_Runtime(cudaMalloc((void**)&da, total_bytes));
    Check_Cuda_Runtime(cudaMalloc((void**)&db, total_bytes));
    Check_Cuda_Runtime(cudaMalloc((void**)&dc, total_bytes));
    Check_Cuda_Runtime(cudaMemcpy(da, ha, total_bytes, cudaMemcpyHostToDevice));
    Check_Cuda_Runtime(cudaMemcpy(db, hb, total_bytes, cudaMemcpyHostToDevice));

    launch_cuda_add<dtype>(da, db, dc, total_nums, blocksize, vector_load);
    // Catch launch-configuration errors; the blocking memcpy below also
    // synchronizes with and surfaces any asynchronous execution error.
    Check_Cuda_Runtime(cudaGetLastError());
    Check_Cuda_Runtime(cudaMemcpy(hc, dc, total_bytes, cudaMemcpyDeviceToHost));

    launch_cpu_add<dtype>(ha, hb, hc_, total_nums);
    check_cuda_cpu_result<dtype>(hc, hc_, total_nums);

    free(ha);
    free(hb);
    free(hc);
    free(hc_);  // was leaked before
    Check_Cuda_Runtime(cudaFree(da));
    Check_Cuda_Runtime(cudaFree(db));
    Check_Cuda_Runtime(cudaFree(dc));

    // Success exit code (the old `return -1` signalled failure to the shell).
    return 0;
}